{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Graph Convolutional Neural Network (GCN)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook contains a complete [TensorFlow](https://www.tensorflow.org/) implementation of a two-layer graph convolutional neural network (GCN) for link prediction. It follows the GCN formulation presented in [Kipf and Welling, ICLR 2017](https://arxiv.org/pdf/1609.02907.pdf)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from __future__ import division\n",
    "from __future__ import print_function\n",
    "\n",
    "import time\n",
    "\n",
    "import numpy as np\n",
    "import scipy.sparse as sp\n",
    "\n",
    "import networkx as nx\n",
    "import tensorflow as tf\n",
    "from sklearn.metrics import roc_auc_score\n",
    "from sklearn.metrics import average_precision_score\n",
    "import os\n",
    "os.environ['VISIBLE_CUDA_DEVICES'] = '1'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Seed NumPy and TensorFlow so that repeated runs draw the same random numbers\n",
    "seed = 123\n",
    "np.random.seed(seed)\n",
    "tf.set_random_seed(seed)\n",
    "\n",
    "# Hyperparameters, registered with TensorFlow's flag module and read via FLAGS\n",
    "flags = tf.app.flags\n",
    "FLAGS = flags.FLAGS\n",
    "flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')\n",
    "flags.DEFINE_integer('epochs', 20, 'Number of epochs to train.')\n",
    "flags.DEFINE_integer('hidden1', 32, 'Number of units in hidden layer 1.')\n",
    "flags.DEFINE_integer('hidden2', 16, 'Number of units in hidden layer 2.')\n",
    "flags.DEFINE_float('dropout', 0.1, 'Dropout rate (1 - keep probability).')\n",
    "# NOTE(review): presumably this dummy flag absorbs the `-f <connection file>`\n",
    "# argument that Jupyter passes to the kernel process - confirm before removing.\n",
    "flags.DEFINE_string('f', '', 'kernel')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Various Utility Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def load_data(path='yeast.edgelist'):\n",
    "    \"\"\"Read an edge list from disk and return the graph's adjacency matrix.\n",
    "\n",
    "    Args:\n",
    "        path: Edge-list file passed to ``networkx.read_edgelist``. Defaults to\n",
    "            'yeast.edgelist', the file this notebook was written for.\n",
    "\n",
    "    Returns:\n",
    "        The adjacency matrix produced by ``networkx.adjacency_matrix``.\n",
    "    \"\"\"\n",
    "    graph = nx.read_edgelist(path)\n",
    "    return nx.adjacency_matrix(graph)\n",
    "\n",
    "\n",
    "def weight_variable_glorot(input_dim, output_dim, name=\"\"):\n",
    "    \"\"\"Create a float32 weight variable of shape [input_dim, output_dim].\n",
    "\n",
    "    Initial values are drawn uniformly from [-r, r) with\n",
    "    r = sqrt(6 / (input_dim + output_dim)).\n",
    "    \"\"\"\n",
    "    limit = np.sqrt(6.0 / (input_dim + output_dim))\n",
    "    shape = [input_dim, output_dim]\n",
    "    init_values = tf.random_uniform(shape, minval=-limit, maxval=limit, dtype=tf.float32)\n",
    "    return tf.Variable(init_values, name=name)\n",
    "\n",
    "\n",
    "def dropout_sparse(x, keep_prob, num_nonzero_elems):\n",
    "    \"\"\"Apply dropout to the stored entries of a sparse tensor.\n",
    "\n",
    "    Args:\n",
    "        x: Sparse tensor whose entries are randomly dropped.\n",
    "        keep_prob: Probability of keeping each entry.\n",
    "        num_nonzero_elems: Number of stored entries in `x` (length of the mask).\n",
    "\n",
    "    Returns:\n",
    "        The retained entries of `x`, scaled by 1 / keep_prob.\n",
    "    \"\"\"\n",
    "    # floor(keep_prob + U[0, 1)) equals 1 with probability keep_prob, else 0\n",
    "    keep_scores = keep_prob + tf.random_uniform([num_nonzero_elems])\n",
    "    keep_mask = tf.cast(tf.floor(keep_scores), dtype=tf.bool)\n",
    "    retained = tf.sparse_retain(x, keep_mask)\n",
    "    return retained * (1. / keep_prob)\n",
    "\n",
    "\n",
    "def sparse_to_tuple(sparse_mx):\n",
    "    if not sp.isspmatrix_coo(sparse_mx):\n",
    "        sparse_mx = sparse_mx.tocoo()\n",
    "    coords = np.vstack((sparse_mx.row, sparse_mx.col)).transpose()\n",
    "    values = sparse_mx.data\n",
    "    shape = sparse_mx.shape\n",
    "    return coords, values, shape\n",
    "\n",
    "\n",
    "def preprocess_graph(adj):\n",
    "    \"\"\"Add self-loops to `adj` and scale it by D^-1/2 on both sides.\n",
    "\n",
    "    D is the diagonal matrix of row sums of (adj + I). The result is returned\n",
    "    in the (coords, values, shape) form produced by `sparse_to_tuple`.\n",
    "    \"\"\"\n",
    "    adj_coo = sp.coo_matrix(adj)\n",
    "    adj_self_loops = adj_coo + sp.eye(adj_coo.shape[0])\n",
    "    row_sums = np.array(adj_self_loops.sum(1))\n",
    "    inv_sqrt_degree = sp.diags(np.power(row_sums, -0.5).flatten())\n",
    "    scaled = adj_self_loops.dot(inv_sqrt_degree).transpose().dot(inv_sqrt_degree)\n",
    "    return sparse_to_tuple(scaled.tocoo())\n",
    "\n",
    "\n",
    "def construct_feed_dict(adj_normalized, adj, features, placeholders):\n",
    "    feed_dict = dict()\n",
    "    feed_dict.update({placeholders['features']: features})\n",
    "    feed_dict.update({placeholders['adj']: adj_normalized})\n",
    "    feed_dict.update({placeholders['adj_orig']: adj})\n",
    "    return feed_dict\n",
    "\n",
    "\n",
    "def _sample_false_edges(num_needed, num_nodes, forbidden):\n",
    "    \"\"\"Draw `num_needed` distinct node pairs that are not in `forbidden`.\n",
    "\n",
    "    Args:\n",
    "        num_needed: Number of negative (non-edge) pairs to return.\n",
    "        num_nodes: Node indices are drawn uniformly from [0, num_nodes).\n",
    "        forbidden: Set of (i, j) tuples that must never be returned; it has to\n",
    "            contain both orientations of every pair to exclude.\n",
    "\n",
    "    Returns:\n",
    "        List of [i, j] pairs with i != j, none of them in `forbidden`, and no\n",
    "        pair repeated in either orientation. The loop only terminates if\n",
    "        enough such pairs exist.\n",
    "    \"\"\"\n",
    "    sampled = []\n",
    "    seen = set()  # both orientations of every pair already sampled\n",
    "    while len(sampled) < num_needed:\n",
    "        n_rnd = num_needed - len(sampled)\n",
    "        rnd = np.random.randint(0, num_nodes, size=2 * n_rnd)\n",
    "        for idx_i, idx_j in zip(rnd[:n_rnd].tolist(), rnd[n_rnd:].tolist()):\n",
    "            if idx_i == idx_j:\n",
    "                continue\n",
    "            pair = (idx_i, idx_j)\n",
    "            if pair in forbidden or pair in seen:\n",
    "                continue\n",
    "            seen.add(pair)\n",
    "            seen.add((idx_j, idx_i))\n",
    "            sampled.append([idx_i, idx_j])\n",
    "    return sampled\n",
    "\n",
    "\n",
    "def mask_test_edges(adj):\n",
    "    \"\"\"Split the edges of `adj` into train, validation and test sets.\n",
    "\n",
    "    Validation and test each receive floor(E / 50) (about 2%) of the E\n",
    "    upper-triangle edges, chosen at random; the remaining edges form the\n",
    "    training graph. For each held-out set an equally sized list of negative\n",
    "    pairs is sampled from node pairs that are not edges of the full graph.\n",
    "\n",
    "    Args:\n",
    "        adj: SciPy sparse adjacency matrix; diagonal entries are ignored.\n",
    "\n",
    "    Returns:\n",
    "        (adj_train, train_edges, val_edges, val_edges_false, test_edges,\n",
    "        test_edges_false), where adj_train is the symmetric adjacency matrix\n",
    "        built from train_edges, the positive edge sets are arrays of [i, j]\n",
    "        rows and the negative sets are lists of [i, j] pairs.\n",
    "    \"\"\"\n",
    "    # Remove diagonal elements\n",
    "    adj = adj - sp.dia_matrix((adj.diagonal()[np.newaxis, :], [0]), shape=adj.shape)\n",
    "    adj.eliminate_zeros()\n",
    "\n",
    "    # The upper triangle lists each undirected edge once\n",
    "    edges = sparse_to_tuple(sp.triu(adj))[0]\n",
    "    edges_all = sparse_to_tuple(adj)[0]\n",
    "    num_test = int(np.floor(edges.shape[0] / 50.))\n",
    "    num_val = int(np.floor(edges.shape[0] / 50.))\n",
    "\n",
    "    # np.random.shuffle works in place, so it needs a mutable array; a\n",
    "    # Python 3 `range` object would raise TypeError here.\n",
    "    all_edge_idx = np.arange(edges.shape[0])\n",
    "    np.random.shuffle(all_edge_idx)\n",
    "    val_edge_idx = all_edge_idx[:num_val]\n",
    "    test_edge_idx = all_edge_idx[num_val:(num_val + num_test)]\n",
    "    test_edges = edges[test_edge_idx]\n",
    "    val_edges = edges[val_edge_idx]\n",
    "    train_edges = np.delete(edges, np.hstack([test_edge_idx, val_edge_idx]), axis=0)\n",
    "\n",
    "    # Every true edge, in both orientations. Negatives for validation AND test\n",
    "    # are checked against the full graph so that a held-out positive edge can\n",
    "    # never be handed out as a negative example.\n",
    "    true_pairs = set()\n",
    "    for idx_i, idx_j in edges_all.tolist():\n",
    "        true_pairs.add((idx_i, idx_j))\n",
    "        true_pairs.add((idx_j, idx_i))\n",
    "\n",
    "    num_nodes = adj.shape[0]\n",
    "    test_edges_false = _sample_false_edges(len(test_edges), num_nodes, true_pairs)\n",
    "    val_edges_false = _sample_false_edges(len(val_edges), num_nodes, true_pairs)\n",
    "\n",
    "    # Re-build adj matrix\n",
    "    data = np.ones(train_edges.shape[0])\n",
    "    adj_train = sp.csr_matrix((data, (train_edges[:, 0], train_edges[:, 1])), shape=adj.shape)\n",
    "    adj_train = adj_train + adj_train.T\n",
    "\n",
    "    return adj_train, train_edges, val_edges, val_edges_false, test_edges, test_edges_false\n",
    "\n",
    "\n",
    "def get_roc_score(edges_pos, edges_neg):\n",
    "    \"\"\"Score held-out node pairs with the current model.\n",
    "\n",
    "    Uses the notebook globals `feed_dict`, `placeholders`, `sess` and `model`.\n",
    "\n",
    "    Args:\n",
    "        edges_pos: Sequence of [i, j] pairs that are true edges.\n",
    "        edges_neg: Sequence of [i, j] pairs that are not edges.\n",
    "\n",
    "    Returns:\n",
    "        (roc_score, ap_score): ROC AUC and average precision of the predicted\n",
    "        edge probabilities, with `edges_pos` labelled 1 and `edges_neg` 0.\n",
    "    \"\"\"\n",
    "    # Evaluate without dropout (this updates the shared global feed_dict in place)\n",
    "    feed_dict.update({placeholders['dropout']: 0})\n",
    "    emb = sess.run(model.embeddings, feed_dict=feed_dict)\n",
    "\n",
    "    def sigmoid(x):\n",
    "        return 1 / (1 + np.exp(-x))\n",
    "\n",
    "    # Predicted edge score for (i, j) is the inner product of the two embeddings\n",
    "    adj_rec = np.dot(emb, emb.T)\n",
    "    preds = [sigmoid(adj_rec[e[0], e[1]]) for e in edges_pos]\n",
    "    preds_neg = [sigmoid(adj_rec[e[0], e[1]]) for e in edges_neg]\n",
    "\n",
    "    preds_all = np.hstack([preds, preds_neg])\n",
    "    # One label per prediction: the zero block is sized by the negatives, so the\n",
    "    # labels stay aligned even when the two edge lists differ in length\n",
    "    labels_all = np.hstack([np.ones(len(preds)), np.zeros(len(preds_neg))])\n",
    "    roc_score = roc_auc_score(labels_all, preds_all)\n",
    "    ap_score = average_precision_score(labels_all, preds_all)\n",
    "\n",
    "    return roc_score, ap_score"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define Convolutional Layers for our GCN Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class GraphConvolution:\n",
    "    \"\"\"Basic graph convolution layer for undirected graph without edge labels.\n",
    "\n",
    "    Calling the layer computes act(adj * (dropout(inputs) * weights)) for a\n",
    "    dense `inputs` tensor and a sparse `adj` tensor.\n",
    "    \"\"\"\n",
    "    def __init__(self, input_dim, output_dim, adj, name, dropout=0., act=tf.nn.relu):\n",
    "        self.name = name\n",
    "        self.issparse = False\n",
    "        self.adj = adj\n",
    "        self.act = act\n",
    "        self.dropout = dropout\n",
    "        with tf.variable_scope(self.name + '_vars'):\n",
    "            weights = weight_variable_glorot(input_dim, output_dim, name='weights')\n",
    "        self.vars = {'weights': weights}\n",
    "\n",
    "    def __call__(self, inputs):\n",
    "        with tf.name_scope(self.name):\n",
    "            # tf.nn.dropout takes the keep probability, not the drop rate\n",
    "            dropped = tf.nn.dropout(inputs, 1-self.dropout)\n",
    "            projected = tf.matmul(dropped, self.vars['weights'])\n",
    "            propagated = tf.sparse_tensor_dense_matmul(self.adj, projected)\n",
    "            outputs = self.act(propagated)\n",
    "        return outputs\n",
    "\n",
    "\n",
    "class GraphConvolutionSparse:\n",
    "    \"\"\"Graph convolution layer for sparse inputs.\n",
    "\n",
    "    Calling the layer computes act(adj * (dropout(inputs) * weights)) where\n",
    "    both `inputs` and `adj` are sparse tensors; `features_nonzero` is the\n",
    "    number of stored entries of `inputs` and sizes the dropout mask.\n",
    "    \"\"\"\n",
    "    def __init__(self, input_dim, output_dim, adj, features_nonzero, name, dropout=0., act=tf.nn.relu):\n",
    "        self.name = name\n",
    "        self.issparse = True\n",
    "        self.adj = adj\n",
    "        self.act = act\n",
    "        self.dropout = dropout\n",
    "        self.features_nonzero = features_nonzero\n",
    "        with tf.variable_scope(self.name + '_vars'):\n",
    "            weights = weight_variable_glorot(input_dim, output_dim, name='weights')\n",
    "        self.vars = {'weights': weights}\n",
    "\n",
    "    def __call__(self, inputs):\n",
    "        with tf.name_scope(self.name):\n",
    "            dropped = dropout_sparse(inputs, 1-self.dropout, self.features_nonzero)\n",
    "            projected = tf.sparse_tensor_dense_matmul(dropped, self.vars['weights'])\n",
    "            propagated = tf.sparse_tensor_dense_matmul(self.adj, projected)\n",
    "            outputs = self.act(propagated)\n",
    "        return outputs\n",
    "    \n",
    "    \n",
    "class InnerProductDecoder:\n",
    "    \"\"\"Decoder model layer for link prediction.\n",
    "\n",
    "    Calling the layer returns act(Z * Z^T) flattened to one dimension, where Z\n",
    "    is the input after dropout.\n",
    "    \"\"\"\n",
    "    def __init__(self, input_dim, name, dropout=0., act=tf.nn.sigmoid):\n",
    "        # input_dim is accepted but not used: this layer has no weights\n",
    "        self.name = name\n",
    "        self.act = act\n",
    "        self.dropout = dropout\n",
    "        self.issparse = False\n",
    "\n",
    "    def __call__(self, inputs):\n",
    "        with tf.name_scope(self.name):\n",
    "            z = tf.nn.dropout(inputs, 1-self.dropout)\n",
    "            scores = tf.matmul(z, tf.transpose(z))\n",
    "            flat_scores = tf.reshape(scores, [-1])\n",
    "            outputs = self.act(flat_scores)\n",
    "        return outputs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Specify the Architecture of our GCN Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class GCNModel:\n",
    "    \"\"\"Two graph-convolution layers followed by an inner-product decoder.\n",
    "\n",
    "    After construction the model exposes `hidden1` (first layer output),\n",
    "    `embeddings` (second layer output) and `reconstructions` (decoder output;\n",
    "    the decoder is built with an identity activation).\n",
    "    \"\"\"\n",
    "    def __init__(self, placeholders, num_features, features_nonzero, name):\n",
    "        self.name = name\n",
    "        self.input_dim = num_features\n",
    "        self.features_nonzero = features_nonzero\n",
    "        self.inputs = placeholders['features']\n",
    "        self.adj = placeholders['adj']\n",
    "        self.dropout = placeholders['dropout']\n",
    "        with tf.variable_scope(self.name):\n",
    "            self.build()\n",
    "\n",
    "    def build(self):\n",
    "        # Layer 1: sparse input features -> FLAGS.hidden1 units, ReLU\n",
    "        sparse_layer = GraphConvolutionSparse(\n",
    "            name='gcn_sparse_layer',\n",
    "            input_dim=self.input_dim,\n",
    "            output_dim=FLAGS.hidden1,\n",
    "            adj=self.adj,\n",
    "            features_nonzero=self.features_nonzero,\n",
    "            act=tf.nn.relu,\n",
    "            dropout=self.dropout)\n",
    "        self.hidden1 = sparse_layer(self.inputs)\n",
    "\n",
    "        # Layer 2: FLAGS.hidden1 -> FLAGS.hidden2 units, identity activation\n",
    "        dense_layer = GraphConvolution(\n",
    "            name='gcn_dense_layer',\n",
    "            input_dim=FLAGS.hidden1,\n",
    "            output_dim=FLAGS.hidden2,\n",
    "            adj=self.adj,\n",
    "            act=lambda x: x,\n",
    "            dropout=self.dropout)\n",
    "        self.embeddings = dense_layer(self.hidden1)\n",
    "\n",
    "        # Decoder: identity activation, so the outputs are unsquashed scores\n",
    "        decoder = InnerProductDecoder(\n",
    "            name='gcn_decoder',\n",
    "            input_dim=FLAGS.hidden2,\n",
    "            act=lambda x: x)\n",
    "        self.reconstructions = decoder(self.embeddings)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Specify the GCN Optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Optimizer:\n",
    "    \"\"\"Weighted cross-entropy loss on the reconstruction, minimised with Adam.\n",
    "\n",
    "    Attributes set by the constructor: `cost`, `optimizer`, `opt_op` (one\n",
    "    training step) and `grads_vars` (gradient/variable pairs for `cost`).\n",
    "    \"\"\"\n",
    "    def __init__(self, preds, labels, num_nodes, num_edges):\n",
    "        num_pairs = num_nodes**2\n",
    "        # Positive entries are weighted by the ratio of non-edges to edges\n",
    "        pos_weight = float(num_pairs - num_edges) / num_edges\n",
    "        norm = num_pairs / float((num_pairs - num_edges) * 2)\n",
    "\n",
    "        # `preds` are passed as logits, `labels` as the 0/1 targets\n",
    "        entry_losses = tf.nn.weighted_cross_entropy_with_logits(\n",
    "            logits=preds, targets=labels, pos_weight=pos_weight)\n",
    "        self.cost = norm * tf.reduce_mean(entry_losses)\n",
    "        self.optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)  # Adam Optimizer\n",
    "\n",
    "        self.opt_op = self.optimizer.minimize(self.cost)\n",
    "        self.grads_vars = self.optimizer.compute_gradients(self.cost)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train the GCN Model and Evaluate its Accuracy on a Test Set of Protein-Protein Interactions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Given a training set of protein-protein interactions in yeast *S. cerevisiae*, our goal is to take these interactions and train a GCN model that can predict new protein-protein interactions. That is, we would like to predict new edges in the yeast protein interaction network. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Load the full interaction graph\n",
    "adj = load_data()\n",
    "num_nodes = adj.shape[0]\n",
    "# NOTE(review): num_edges is taken from the full graph, before the held-out\n",
    "# edges are removed below; it feeds the loss weights in Optimizer - confirm\n",
    "# this is intended.\n",
    "num_edges = adj.sum()\n",
    "\n",
    "# Featureless: every node gets a one-hot (identity) feature vector\n",
    "features = sparse_to_tuple(sp.identity(num_nodes))\n",
    "num_features = features[2][1]\n",
    "features_nonzero = features[1].shape[0]\n",
    "\n",
    "# Store original adjacency matrix (without diagonal entries) for later\n",
    "adj_diagonal = sp.dia_matrix((adj.diagonal()[np.newaxis, :], [0]), shape=adj.shape)\n",
    "adj_orig = adj - adj_diagonal\n",
    "adj_orig.eliminate_zeros()\n",
    "\n",
    "# Hold out validation/test edges; from here on `adj` is the training graph\n",
    "adj_train, train_edges, val_edges, val_edges_false, test_edges, test_edges_false = mask_test_edges(adj)\n",
    "adj = adj_train\n",
    "\n",
    "adj_norm = preprocess_graph(adj)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Graph inputs; dropout falls back to 0 whenever it is not fed\n",
    "placeholders = {\n",
    "    'features': tf.sparse_placeholder(tf.float32),\n",
    "    'adj': tf.sparse_placeholder(tf.float32),\n",
    "    'adj_orig': tf.sparse_placeholder(tf.float32),\n",
    "    'dropout': tf.placeholder_with_default(0., shape=())\n",
    "}\n",
    "\n",
    "# Create model\n",
    "model = GCNModel(placeholders, num_features, features_nonzero, name='yeast_gcn')\n",
    "\n",
    "# Create optimizer\n",
    "with tf.name_scope('optimizer'):\n",
    "    # Targets: the 'adj_orig' placeholder densified and flattened to match\n",
    "    # the flattened decoder output\n",
    "    labels_flat = tf.reshape(\n",
    "        tf.sparse_tensor_to_dense(placeholders['adj_orig'], validate_indices=False), [-1])\n",
    "    opt = Optimizer(\n",
    "        preds=model.reconstructions,\n",
    "        labels=labels_flat,\n",
    "        num_nodes=num_nodes,\n",
    "        num_edges=num_edges)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0001 train_loss= 0.68120 val_roc= 0.83277 val_ap= 0.80579 time= 1.50844\n",
      "Epoch: 0002 train_loss= 0.68119 val_roc= 0.87144 val_ap= 0.86364 time= 1.29321\n",
      "Epoch: 0003 train_loss= 0.68111 val_roc= 0.87741 val_ap= 0.86811 time= 1.05034\n",
      "Epoch: 0004 train_loss= 0.68066 val_roc= 0.87794 val_ap= 0.86836 time= 1.12609\n",
      "Epoch: 0005 train_loss= 0.67928 val_roc= 0.87799 val_ap= 0.86839 time= 1.12086\n",
      "Epoch: 0006 train_loss= 0.67651 val_roc= 0.87802 val_ap= 0.86841 time= 1.19530\n",
      "Epoch: 0007 train_loss= 0.67130 val_roc= 0.87802 val_ap= 0.86840 time= 1.21475\n",
      "Epoch: 0008 train_loss= 0.66344 val_roc= 0.87802 val_ap= 0.86840 time= 1.29680\n",
      "Epoch: 0009 train_loss= 0.65169 val_roc= 0.87802 val_ap= 0.86840 time= 1.14988\n",
      "Epoch: 0010 train_loss= 0.63824 val_roc= 0.87802 val_ap= 0.86839 time= 1.24414\n",
      "Epoch: 0011 train_loss= 0.62303 val_roc= 0.87802 val_ap= 0.86838 time= 1.19439\n",
      "Epoch: 0012 train_loss= 0.61325 val_roc= 0.87800 val_ap= 0.86838 time= 1.27773\n",
      "Epoch: 0013 train_loss= 0.61614 val_roc= 0.87796 val_ap= 0.86837 time= 1.29762\n",
      "Epoch: 0014 train_loss= 0.62800 val_roc= 0.87792 val_ap= 0.86835 time= 1.22622\n",
      "Epoch: 0015 train_loss= 0.62828 val_roc= 0.87786 val_ap= 0.86833 time= 1.17643\n",
      "Epoch: 0016 train_loss= 0.62178 val_roc= 0.87776 val_ap= 0.86828 time= 1.19120\n",
      "Epoch: 0017 train_loss= 0.61396 val_roc= 0.87763 val_ap= 0.86822 time= 1.38739\n",
      "Epoch: 0018 train_loss= 0.60946 val_roc= 0.87744 val_ap= 0.86811 time= 1.33678\n",
      "Epoch: 0019 train_loss= 0.61013 val_roc= 0.87721 val_ap= 0.86799 time= 1.26785\n",
      "Epoch: 0020 train_loss= 0.61154 val_roc= 0.87702 val_ap= 0.86787 time= 1.32318\n",
      "Optimization Finished!\n",
      "Test ROC score: 0.87898\n",
      "Test AP score: 0.86944\n"
     ]
    }
   ],
   "source": [
    "# Initialize session\n",
    "sess = tf.Session()\n",
    "sess.run(tf.global_variables_initializer())\n",
    "\n",
    "# Training targets: the training adjacency matrix with self-loops added\n",
    "adj_label = sparse_to_tuple(adj_train + sp.eye(adj_train.shape[0]))\n",
    "\n",
    "# Train model\n",
    "for epoch in range(FLAGS.epochs):\n",
    "    epoch_start = time.time()\n",
    "    # get_roc_score() below sets dropout to 0 in this shared dict, so the\n",
    "    # training dropout rate is put back at the start of every epoch\n",
    "    feed_dict = construct_feed_dict(adj_norm, adj_label, features, placeholders)\n",
    "    feed_dict[placeholders['dropout']] = FLAGS.dropout\n",
    "    # One update of parameter matrices\n",
    "    _, avg_cost = sess.run([opt.opt_op, opt.cost], feed_dict=feed_dict)\n",
    "    # Performance on validation set\n",
    "    roc_curr, ap_curr = get_roc_score(val_edges, val_edges_false)\n",
    "\n",
    "    print(\"Epoch:\", '%04d' % (epoch + 1),\n",
    "          \"train_loss=\", \"{:.5f}\".format(avg_cost),\n",
    "          \"val_roc=\", \"{:.5f}\".format(roc_curr),\n",
    "          \"val_ap=\", \"{:.5f}\".format(ap_curr),\n",
    "          \"time=\", \"{:.5f}\".format(time.time() - epoch_start))\n",
    "\n",
    "print('Optimization Finished!')\n",
    "\n",
    "# Final evaluation on the held-out test edges\n",
    "roc_score, ap_score = get_roc_score(test_edges, test_edges_false)\n",
    "print('Test ROC score: {:.5f}'.format(roc_score))\n",
    "print('Test AP score: {:.5f}'.format(ap_score))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
